{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.9.0\n",
      "\n",
      "torch 1.7.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# Gradient Checkpointing Demo (Network-in-Network trained on CIFAR-10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "\n",
    "Why do we care about gradient checkpointing? It can lower the memory requirement of deep neural networks quite substantially, allowing us to work with larger architectures and memory limitations of conventional GPUs. However, there is no free lunch here: as a trade-off for the lower-memory requirements, additional computations are carried out which can prolong the training time. However, when GPU-memory is a limiting factor that we cannot even circumvent by lowering the batch sizes, then gradient checkpointing is a great and easy option for making things work!\n",
    "\n",
    "Below is a brief summary of how gradient checkpointing works. For more details, please see the excellent explanations in [1] and [2].\n",
    "\n",
    "## Vanilla Backpropagation\n",
    "\n",
    "In vanilla backpropagation (the standard version of backpropagation), the required memory grows linearly with the number of layers *n* in the neural network. This is because all nodes from the forward pass are being kept in memory (until all their dependent child nodes are processed).\n",
    "\n",
    "![](figures/gradient-checkpointing-1.png)\n",
    "\n",
    "## Low-memory Backpropagation\n",
    "\n",
    "In the low-memory version of backpropagation, the forward pass is recomputed at each step, making it more memory-efficient than vanilla backpropagation, trading the memory for additional computations. In comparison, vanilla backpropagation processes *n* layers (nodes), the low-memory version processes $n^2$ nodes.\n",
    "\n",
    "![](figures/gradient-checkpointing-2.png)\n",
    "\n",
    "## Gradient Checkpointing\n",
    "\n",
    "The gradient checkpointing method is a compromise between vanilla backpropagation and low-memory backpropagation, where nodes are recomputed more often than in vanilla backpropagation but not as often as in the low-memory version. In gradient checkpointing, we designate certain nodes as checkpoints so that they are not recomputed and serve as a basis for recomputing other nodes. The optimal choice is to designate every `\\sqrt{n}`-th node as a checkpoint node. Consequently, the memory requirement increases by `\\sqrt{n}` compared to the low-memory version of backpropagation.\n",
    "\n",
    "As stated in [3], gradient checkpointing, we can implement models that are 4x to 10x larger than architectures that would usually fit into GPU memory.\n",
    "\n",
    "![](figures/gradient-checkpointing-3.png)\n",
    "\n",
    "## Gradient Checkpointing in PyTorch\n",
    "\n",
    "PyTorch allows us to use gradient checkpointing very conveniently. In this notebook, we are only using the checkpointing for sequential models. However, it is also possible (and not much more complicated) to use checkpointing for non-sequential models. I recommend checking out the tutorial in [3] for more details.\n",
    "\n",
    "A great performance benchmark and write-up is available at [4], showing the difference in memory consumption between a baseline ResNet-18 and one enhanced with gradient checkpointing.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "References\n",
    "-----------\n",
    "\n",
    "\n",
    "[1] Saving memory using gradient-checkpointing: https://github.com/cybertronai/gradient-checkpointing\n",
    "\n",
    "[2] Fitting larger networks into memory: https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9\n",
    "\n",
    "[3] Trading compute for memory in PyTorch models using Checkpointing: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb\n",
    "\n",
    "[4] Deep Learning Memory Usage and Pytorch Optimization Tricks: https://www.sicara.ai/blog/2019-28-10-deep-learning-memory-usage-and-pytorch-optimization-tricks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Network Architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this demo, I am using a simple Network-in-Network (NiN) architecture for the purpose of code readability. The gain from gradient checkpointing can be larger the deeper the architecture. \n",
    "\n",
    "\n",
    "The CNN architecture is based on\n",
    "- Lin, Min, Qiang Chen, and Shuicheng Yan. \"Network in network.\" arXiv preprint arXiv:1312.4400 (2013).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 1: Setup and Baseline (No Gradient Checkpointing)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.dataset import Subset\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setting a random seed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend using a function like the following one prior to using dataset loaders and initializing a model if you want to ensure the data is shuffled in the same manner if you rerun this notebook and the model gets the same initial random weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_all_seeds(seed):\n",
    "    os.environ[\"PL_GLOBAL_SEED\"] = str(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setting cuDNN and PyTorch algorithmic behavior to deterministic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar to the `set_all_seeds` function above, I recommend setting the behavior of PyTorch and cuDNN to deterministic (this is particulary relevant when using GPUs). We can also define a function for that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_deterministic():\n",
    "    if torch.cuda.is_available():\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "    torch.set_deterministic(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda:2\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device\n",
    "CUDA_DEVICE_NUM = 2 # change as appropriate\n",
    "DEVICE = torch.device('cuda:%d' % CUDA_DEVICE_NUM if torch.cuda.is_available() else 'cpu')\n",
    "print('Device:', DEVICE)\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.0001\n",
    "BATCH_SIZE = 256\n",
    "NUM_EPOCHS = 40\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "set_all_seeds(RANDOM_SEED)\n",
    "\n",
    "# Deterministic behavior not yet supported by AdaptiveAvgPool2d\n",
    "#set_deterministic()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import utility functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, \"..\") # to include ../helper_evaluate.py etc.\n",
    "\n",
    "from helper_evaluate import compute_accuracy\n",
    "from helper_data import get_dataloaders_cifar10\n",
    "from helper_train import train_classifier_simple_v1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "### Set random seed ###\n",
    "set_all_seeds(RANDOM_SEED)\n",
    "\n",
    "##########################\n",
    "### Dataset\n",
    "##########################\n",
    "\n",
    "train_loader, valid_loader, test_loader = get_dataloaders_cifar10(\n",
    "    batch_size=BATCH_SIZE, \n",
    "    num_workers=2, \n",
    "    validation_fraction=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Set:\n",
      "\n",
      "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([256])\n",
      "tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3])\n",
      "\n",
      "Validation Set:\n",
      "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([256])\n",
      "tensor([7, 1, 4, 1, 0, 2, 2, 5, 9, 6])\n",
      "\n",
      "Testing Set:\n",
      "Image batch dimensions: torch.Size([256, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([256])\n",
      "tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3])\n"
     ]
    }
   ],
   "source": [
    "# Checking the dataset\n",
    "print('Training Set:\\n')\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.size())\n",
    "    print('Image label dimensions:', labels.size())\n",
    "    print(labels[:10])\n",
    "    break\n",
    "    \n",
    "# Checking the dataset\n",
    "print('\\nValidation Set:')\n",
    "for images, labels in valid_loader:  \n",
    "    print('Image batch dimensions:', images.size())\n",
    "    print('Image label dimensions:', labels.size())\n",
    "    print(labels[:10])\n",
    "    break\n",
    "\n",
    "# Checking the dataset\n",
    "print('\\nTesting Set:')\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.size())\n",
    "    print('Image label dimensions:', labels.size())\n",
    "    print(labels[:10])\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the basic NiN model without gradient checkpointing for reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "class NiN(nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super(NiN, self).__init__()\n",
    "        self.num_classes = num_classes\n",
    "        self.classifier = nn.Sequential(\n",
    "                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(160,  96, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n",
    "\n",
    "                )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.classifier(x)\n",
    "        logits = x.view(x.size(0), self.num_classes)\n",
    "        #probas = torch.softmax(logits, dim=1)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "set_all_seeds(RANDOM_SEED)\n",
    "\n",
    "model = NiN(NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/002 | Batch 0000/0176 | Loss: 2.3045\n",
      "Epoch: 001/002 | Batch 0050/0176 | Loss: 2.2849\n",
      "Epoch: 001/002 | Batch 0100/0176 | Loss: 2.1435\n",
      "Epoch: 001/002 | Batch 0150/0176 | Loss: 2.0891\n",
      "***Epoch: 001/002 | Train. Acc.: 20.751% | Loss: 2.119\n",
      "***Epoch: 001/002 | Valid. Acc.: 20.600% | Loss: 2.121\n",
      "Time elapsed: 0.40 min\n",
      "Epoch: 002/002 | Batch 0000/0176 | Loss: 2.1154\n",
      "Epoch: 002/002 | Batch 0050/0176 | Loss: 2.0218\n",
      "Epoch: 002/002 | Batch 0100/0176 | Loss: 2.0404\n",
      "Epoch: 002/002 | Batch 0150/0176 | Loss: 1.9474\n",
      "***Epoch: 002/002 | Train. Acc.: 26.649% | Loss: 1.978\n",
      "***Epoch: 002/002 | Valid. Acc.: 25.720% | Loss: 1.989\n",
      "Time elapsed: 0.80 min\n",
      "Total Training Time: 0.80 min\n",
      "87518, 143683\n"
     ]
    }
   ],
   "source": [
    "import tracemalloc\n",
    "\n",
    "\n",
    "tracemalloc.start()\n",
    "log_dict = train_classifier_simple_v1(num_epochs=2, model=model, \n",
    "                                      optimizer=optimizer, device=DEVICE, \n",
    "                                      train_loader=train_loader, valid_loader=valid_loader,\n",
    "                                      logging_interval=50)\n",
    "\n",
    "current, peak =  tracemalloc.get_traced_memory()\n",
    "print(f\"{current}, {peak}\")\n",
    "tracemalloc.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Delete model and free memory\n",
    "\n",
    "model.cpu()\n",
    "del model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 2: Modified NiN with Gradient Checkpointing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The changes we have to make to the NiN code are highlighted below. Note that this example uses only 1 segment in `checkpoint_sequential`. Generally, a lower number of segments improves memory efficiency but making the computational performance worse since more values need to be recomputed. For this architecture, I found that `segments=1` represents a good trade-off, though."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "###### NEW ####################################################\n",
    "from torch.utils.checkpoint import checkpoint_sequential\n",
    "###############################################################\n",
    "\n",
    "\n",
    "class NiN(nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super(NiN, self).__init__()\n",
    "        self.num_classes = num_classes\n",
    "        self.classifier = nn.Sequential(\n",
    "                nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(160,  96, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=3, stride=2, padding=1),\n",
    "                nn.Dropout(0.5),\n",
    "\n",
    "                nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.Conv2d(192,  10, kernel_size=1, stride=1, padding=0),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),\n",
    "                )\n",
    "        \n",
    "        ###### NEW ####################################################\n",
    "        self.classifier_modules = [module for k, module in self.classifier._modules.items()]\n",
    "        ###############################################################\n",
    "\n",
    "    def forward(self, x):\n",
    "        \n",
    "        ###### NEW ####################################################\n",
    "        x.requires_grad = True\n",
    "        x = checkpoint_sequential(functions=self.classifier_modules, \n",
    "                                  segments=1, \n",
    "                                  input=x)\n",
    "        ###############################################################\n",
    "        \n",
    "        x = x.view(x.size(0), self.num_classes)\n",
    "        #probas = torch.softmax(x, dim=1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_all_seeds(RANDOM_SEED)\n",
    "\n",
    "model = NiN(NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/002 | Batch 0000/0176 | Loss: 2.3045\n",
      "Epoch: 001/002 | Batch 0050/0176 | Loss: 2.2849\n",
      "Epoch: 001/002 | Batch 0100/0176 | Loss: 2.1435\n",
      "Epoch: 001/002 | Batch 0150/0176 | Loss: 2.0891\n",
      "***Epoch: 001/002 | Train. Acc.: 20.751% | Loss: 2.119\n",
      "***Epoch: 001/002 | Valid. Acc.: 20.600% | Loss: 2.121\n",
      "Time elapsed: 0.47 min\n",
      "Epoch: 002/002 | Batch 0000/0176 | Loss: 2.1154\n",
      "Epoch: 002/002 | Batch 0050/0176 | Loss: 2.0218\n",
      "Epoch: 002/002 | Batch 0100/0176 | Loss: 2.0404\n",
      "Epoch: 002/002 | Batch 0150/0176 | Loss: 1.9474\n",
      "***Epoch: 002/002 | Train. Acc.: 26.649% | Loss: 1.978\n",
      "***Epoch: 002/002 | Valid. Acc.: 25.720% | Loss: 1.989\n",
      "Time elapsed: 0.93 min\n",
      "Total Training Time: 0.93 min\n",
      "57806, 115055\n"
     ]
    }
   ],
   "source": [
    "tracemalloc.start()\n",
    "log_dict = train_classifier_simple_v1(num_epochs=2, model=model, \n",
    "                                      optimizer=optimizer, device=DEVICE, \n",
    "                                      train_loader=train_loader, valid_loader=valid_loader,\n",
    "                                      logging_interval=50)\n",
    "\n",
    "current, peak =  tracemalloc.get_traced_memory()\n",
    "print(f\"{current}, {peak}\")\n",
    "tracemalloc.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the gradient checkpointing improves peak memory efficiency by approximately 22% while the computational performance (runtime) becomes only 14% worse."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
